{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "22942f7e-3446-4009-b551-cca7fcc25d73",
   "metadata": {},
   "source": [
    "# Reflexion\n",
    "\n",
    "[Reflexion](https://arxiv.org/abs/2303.11366) by Shinn, et. al., is an architecture designed to learn through verbal feedback and self-reflection. The agent explicitly critiques its responses for tasks to generate a higher quality final response, at the expense of longer execution time.\n",
    "\n",
    "![reflexion diagram](./img/reflexion.png)\n",
    "\n",
    "The paper outlines 3 main components:\n",
    "\n",
    "1. Actor (agent) with self-reflection\n",
    "2. External evaluator (task-specific, e.g. code compilation steps)\n",
    "3. Episodic memory that stores the reflections from (1).\n",
    "\n",
    "In their code, the last two components are very task-specific, so in this notebook, you will build the _actor_ in LangGraph.\n",
    "\n",
    "To skip to the graph definition, see the [Construct Graph section](#Construct-Graph) below."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "906edf48-7c81-48b8-8250-fdc34043d01b",
   "metadata": {},
   "source": [
    "## 0. Prerequisites\n",
    "\n",
    "Install `langgraph` (for the framework), `langchain_openai` (for the LLM), and `langchain` + `tavily-python` (for the search engine).\n",
    "\n",
    "We will use tavily search as a tool. You can get an API key [here](https://app.tavily.com/sign-in) or replace with a different tool of your choosing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1b64a6f6-1d32-48be-92b5-66c3b04b17f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install -U --quiet  langchain langgraph langchain_openai\n",
    "# %pip install -U --quiet tavily-python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a917bb70-f84c-48e6-8d32-d14f9df2ca2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "\n",
    "def _set_if_undefined(var: str) -> None:\n",
    "    if os.environ.get(var):\n",
    "        return\n",
    "    os.environ[var] = getpass.getpass(var)\n",
    "\n",
    "\n",
    "# Optional: Configure tracing to visualize and debug the agent\n",
    "_set_if_undefined(\"LANGCHAIN_API_KEY\")\n",
    "os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
    "os.environ[\"LANGCHAIN_PROJECT\"] = \"Reflexion\"\n",
    "\n",
    "_set_if_undefined(\"OPENAI_API_KEY\")\n",
    "_set_if_undefined(\"TAVILY_API_KEY\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af543598-52d0-4ec3-a05f-d2954ff793ee",
   "metadata": {},
   "source": [
    "## 1. Actor (with reflection)\n",
    "\n",
    "The main component of Reflexion is the \"actor\", which is an agent that reflects on its response and re-executes to improve based on self-critique. It's main sub-components include:\n",
    "1. Tools/tool execution\n",
    "2. Initial responder: generate an initial response (and self-reflection)\n",
    "3. Revisor: re-respond (and reflec) based on previous reflections\n",
    "\n",
    "We'll first define the tool execution context.\n",
    "\n",
    "#### Construct tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5a2ac853-b8a6-40de-b7fe-3f9f3c5ca4d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.tools.tavily_search import TavilySearchResults\n",
    "from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper\n",
    "\n",
    "search = TavilySearchAPIWrapper()\n",
    "tavily_tool = TavilySearchResults(api_wrapper=search, max_results=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "737fd31c-b4e9-47f9-a4e8-99fab340a028",
   "metadata": {},
   "source": [
    "The tools are invoked _in context_. Create a function that invokes all the requested tools."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "82f144fc-e6fa-4e4f-a8af-1a0650e79fe3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "from typing import List\n",
    "\n",
    "from langchain.output_parsers.openai_tools import (\n",
    "    JsonOutputToolsParser,\n",
    "    PydanticToolsParser,\n",
    ")\n",
    "from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage\n",
    "from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation\n",
    "\n",
    "# This a helper class we have that is useful for running tools\n",
    "# It takes in an agent action and calls that tool and returns the result\n",
    "tool_executor = ToolExecutor([tavily_tool])\n",
    "# Parse the tool messages for the execution / invocation\n",
    "parser = JsonOutputToolsParser(return_id=True)\n",
    "\n",
    "\n",
    "def execute_tools(state: List[BaseMessage]) -> List[BaseMessage]:\n",
    "    tool_invocation: AIMessage = state[-1]\n",
    "    parsed_tool_calls = parser.invoke(tool_invocation)\n",
    "    ids = []\n",
    "    tool_invocations = []\n",
    "    for parsed_call in parsed_tool_calls:\n",
    "        for query in parsed_call[\"args\"][\"search_queries\"]:\n",
    "            tool_invocations.append(\n",
    "                ToolInvocation(\n",
    "                    # We only have this one for now. Would want to map it\n",
    "                    # if we change\n",
    "                    tool=\"tavily_search_results_json\",\n",
    "                    tool_input=query,\n",
    "                )\n",
    "            )\n",
    "            ids.append(parsed_call[\"id\"])\n",
    "\n",
    "    outputs = tool_executor.batch(tool_invocations)\n",
    "    outputs_map = defaultdict(dict)\n",
    "    for id_, output, invocation in zip(ids, outputs, tool_invocations):\n",
    "        outputs_map[id_][invocation.tool_input] = output\n",
    "\n",
    "    return [\n",
    "        ToolMessage(content=json.dumps(query_outputs), tool_call_id=id_)\n",
    "        for id_, query_outputs in outputs_map.items()\n",
    "    ]"
   ]
  },
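  {
   "cell_type": "markdown",
   "id": "sketch-group-tool-outputs",
   "metadata": {},
   "source": [
    "The regrouping step in `execute_tools` can be easier to follow in isolation. The following is a minimal sketch using only the standard library; `fake_search` and the sample `parsed_tool_calls` are hypothetical stand-ins for the Tavily tool and the parser output, not part of the notebook's pipeline:\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "# Hypothetical parsed tool calls, shaped like the parser output used above.\n",
    "parsed_tool_calls = [\n",
    "    {'id': 'call_a', 'args': {'search_queries': ['q1', 'q2']}},\n",
    "    {'id': 'call_b', 'args': {'search_queries': ['q3']}},\n",
    "]\n",
    "\n",
    "\n",
    "def fake_search(query):\n",
    "    # Stand-in for the search tool; returns a placeholder result.\n",
    "    return 'result for ' + query\n",
    "\n",
    "\n",
    "# Flatten every query, remembering which tool call requested it.\n",
    "ids, queries = [], []\n",
    "for call in parsed_tool_calls:\n",
    "    for query in call['args']['search_queries']:\n",
    "        ids.append(call['id'])\n",
    "        queries.append(query)\n",
    "\n",
    "outputs = [fake_search(query) for query in queries]\n",
    "\n",
    "# Regroup the results so each tool call gets exactly one reply.\n",
    "outputs_map = defaultdict(dict)\n",
    "for id_, query, output in zip(ids, queries, outputs):\n",
    "    outputs_map[id_][query] = output\n",
    "\n",
    "grouped = dict(outputs_map)\n",
    "```\n",
    "\n",
    "Grouping by tool call id means each tool call is answered by a single message, even when it requested several search queries."
   ]
  },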
  {
   "cell_type": "markdown",
   "id": "093fbaa0-9a71-4c32-9872-02a9aec9b35d",
   "metadata": {},
   "source": [
    "#### Initial responder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "5fffa8d5-068a-4f0b-adfc-b4daf30ef294",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "\n",
    "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
    "from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langsmith import traceable\n",
    "\n",
    "actor_prompt_template = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"system\",\n",
    "            \"\"\"You are expert researcher.\n",
    "Current time: {time}\n",
    "\n",
    "1. {first_instruction}\n",
    "2. Reflect and critique your answer. Be severe to maximize improvement.\n",
    "3. Recommend search queries to research information and improve your answer.\"\"\",\n",
    "        ),\n",
    "        MessagesPlaceholder(variable_name=\"messages\"),\n",
    "        (\"system\", \"Answer the user's question above using the required format.\"),\n",
    "    ]\n",
    ").partial(\n",
    "    time=lambda: datetime.datetime.now().isoformat(),\n",
    ")\n",
    "\n",
    "\n",
    "class Reflection(BaseModel):\n",
    "    missing: str = Field(description=\"Critique of what is missing.\")\n",
    "    superfluous: str = Field(description=\"Critique of what is superfluous\")\n",
    "\n",
    "\n",
    "class AnswerQuestion(BaseModel):\n",
    "    \"\"\"Answer the question.\"\"\"\n",
    "\n",
    "    answer: str = Field(description=\"~250 word detailed answer to the question.\")\n",
    "    reflection: Reflection = Field(description=\"Your reflection on the initial answer.\")\n",
    "    search_queries: List[str] = Field(\n",
    "        description=\"1-3 search queries for researching improvements to address the critique of your current answer.\"\n",
    "    )\n",
    "\n",
    "\n",
    "llm = ChatOpenAI(model=\"gpt-4-turbo-preview\")\n",
    "initial_answer_chain = actor_prompt_template.partial(\n",
    "    first_instruction=\"Provide a detailed ~250 word answer.\"\n",
    ") | llm.bind_tools(tools=[AnswerQuestion], tool_choice=\"AnswerQuestion\")\n",
    "validator = PydanticToolsParser(tools=[AnswerQuestion])\n",
    "\n",
    "\n",
    "class ResponderWithRetries:\n",
    "    def __init__(self, runnable, validator):\n",
    "        self.runnable = runnable\n",
    "        self.validator = validator\n",
    "\n",
    "    @traceable\n",
    "    def respond(self, state: List[BaseMessage]):\n",
    "        response = []\n",
    "        for attempt in range(3):\n",
    "            try:\n",
    "                response = self.runnable.invoke({\"messages\": state})\n",
    "                self.validator.invoke(response)\n",
    "                return response\n",
    "            except ValidationError as e:\n",
    "                state = state + [HumanMessage(content=repr(e))]\n",
    "        return response"
   ]
  },
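  {
   "cell_type": "markdown",
   "id": "sketch-retry-loop",
   "metadata": {},
   "source": [
    "The retry behaviour of `ResponderWithRetries` can be sketched without an LLM. This is an illustrative toy under stated assumptions: `FakeValidationError`, `generate`, and `validate` are hypothetical stand-ins for the pydantic error, the chain, and the tools parser:\n",
    "\n",
    "```python\n",
    "class FakeValidationError(Exception):\n",
    "    # Hypothetical stand-in for pydantic's ValidationError.\n",
    "    pass\n",
    "\n",
    "\n",
    "def respond_with_retries(generate, validate, state, max_attempts=3):\n",
    "    # Call generate, feeding validation errors back as extra messages.\n",
    "    response = None\n",
    "    for _ in range(max_attempts):\n",
    "        response = generate(state)\n",
    "        try:\n",
    "            validate(response)\n",
    "            return response\n",
    "        except FakeValidationError as error:\n",
    "            state = state + [repr(error)]\n",
    "    # Give up and return the last (unvalidated) response.\n",
    "    return response\n",
    "\n",
    "\n",
    "def generate(state):\n",
    "    # Toy generator: only answers correctly after seeing an error message.\n",
    "    return 'valid' if len(state) > 1 else 'invalid'\n",
    "\n",
    "\n",
    "def validate(response):\n",
    "    if response != 'valid':\n",
    "        raise FakeValidationError('response failed schema check')\n",
    "\n",
    "\n",
    "result = respond_with_retries(generate, validate, ['question'])\n",
    "```\n",
    "\n",
    "Appending the error text to the message list gives the next attempt a chance to correct the output format. As in the class above, the last response is returned even if every attempt fails validation."
   ]
  },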
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "4a0264b8-ed2d-4f15-9d3c-085aa3a5edab",
   "metadata": {},
   "outputs": [],
   "source": [
    "first_responder = ResponderWithRetries(\n",
    "    runnable=initial_answer_chain, validator=validator\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "5922e1fe-7533-4f41-8b1d-d812707c1968",
   "metadata": {},
   "outputs": [],
   "source": [
    "example_question = \"Why is reflection useful in AI?\"\n",
    "initial = first_responder.respond([HumanMessage(content=example_question)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20901d77-0f5f-4596-90a4-412c0ac5f2c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "parsed = parser.invoke(initial)\n",
    "parsed"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4c7af31-b469-46fc-b441-0acb28515c7a",
   "metadata": {},
   "source": [
    "#### Revision\n",
    "\n",
    "The second part of the actor is a revision step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "2605fd8d-c663-446f-ba25-751190195749",
   "metadata": {},
   "outputs": [],
   "source": [
    "revise_instructions = \"\"\"Revise your previous answer using the new information.\n",
    "    - You should use the previous critique to add important information to your answer.\n",
    "        - You MUST include numerical citations in your revised answer to ensure it can be verified.\n",
    "        - Add a \"References\" section to the bottom of your answer (which does not count towards the word limit). In form of:\n",
    "            - [1] https://example.com\n",
    "            - [2] https://example.com\n",
    "    - You should use the previous critique to remove superfluous information from your answer and make SURE it is not more than 250 words.\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "# Extend the initial answer schema to include references.\n",
    "# Forcing citation in the model encourages grounded responses\n",
    "class ReviseAnswer(AnswerQuestion):\n",
    "    \"\"\"Revise your original answer to your question.\"\"\"\n",
    "\n",
    "    references: List[str] = Field(\n",
    "        description=\"Citations motivating your updated answer.\"\n",
    "    )\n",
    "\n",
    "\n",
    "revision_chain = actor_prompt_template.partial(\n",
    "    first_instruction=revise_instructions\n",
    ") | llm.bind_tools(tools=[ReviseAnswer], tool_choice=\"ReviseAnswer\")\n",
    "revision_validator = PydanticToolsParser(tools=[ReviseAnswer])\n",
    "\n",
    "revisor = ResponderWithRetries(runnable=revision_chain, validator=revision_validator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "6fd51f17-c0b0-44b6-90e2-55a66cb8f5a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "revised = revisor.respond(\n",
    "    [\n",
    "        HumanMessage(content=\"\"),\n",
    "        initial,\n",
    "        ToolMessage(\n",
    "            tool_call_id=initial.additional_kwargs[\"tool_calls\"][0][\"id\"],\n",
    "            content=json.dumps(\n",
    "                tavily_tool.invoke(str(parsed[0][\"args\"][\"search_queries\"]))\n",
    "            ),\n",
    "        ),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "28685e81-5461-47fe-bebd-2af6e552761b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'type': 'ReviseAnswer',\n",
       "  'args': {'answer': \"Reflection in AI refers to the ability of AI systems to analyze and adapt their behavior and algorithms autonomously. This introspective capability enhances AI's performance and adaptability, making it crucial for learning, transparency, and optimization. \\n\\nReflection enables AI to learn from experiences, adjusting strategies for better decision-making. For example, Google DeepMind's AI has shown significant advancements in learning and adapting strategies in various environments [1]. Moreover, AI systems can explain their decisions, supporting the development of explainable AI (XAI), vital in sensitive sectors like healthcare and autonomous driving. This increases user trust and acceptance by providing insights into AI's decision-making processes. \\n\\nAdditionally, reflection aids in debugging and improving AI models by identifying weaknesses and suggesting enhancements. For instance, AI in healthcare, like the Mayo Clinic's use of medical data analytics, demonstrates how reflective AI can optimize algorithms to provide better patient care [2]. \\n\\nIn summary, reflection in AI fosters learning and adaptation, enhances transparency and trust, and facilitates model optimization, contributing to the development of sophisticated, reliable AI systems.\",\n",
       "   'reflection': {'missing': 'The previous answer lacked specific examples and case studies to illustrate the benefits of reflection in AI. Including such examples would provide a more concrete understanding of the concept and its applications.',\n",
       "    'superfluous': 'The initial answer was comprehensive but could benefit from direct examples to demonstrate the practical applications and benefits of reflection in AI, rather than a broad overview without concrete cases.'},\n",
       "   'search_queries': ['Google DeepMind reflective AI examples',\n",
       "    'Mayo Clinic AI case study',\n",
       "    'Reflective AI benefits in healthcare'],\n",
       "   'references': ['https://casestudybuddy.com/blog/best-ai-case-study-examples/',\n",
       "    'https://indatalabs.com/blog/artificial-intelligence-case-studies']},\n",
       "  'id': 'call_0kZkZgn5DP2Z8VhRtxGkXDp5'}]"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "parsed = parser.invoke(revised)\n",
    "parsed"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e623a6c9-b69b-438c-9e6e-34a8883e0623",
   "metadata": {},
   "source": [
    "## Construct Graph\n",
    "\n",
    "\n",
    "Now we can wire all our components together."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "3c57318f-a30c-4dbd-9b88-f2633e8cb3b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.graph import END, MessageGraph\n",
    "\n",
    "MAX_ITERATIONS = 5\n",
    "builder = MessageGraph()\n",
    "builder.add_node(\"draft\", first_responder.respond)\n",
    "builder.add_node(\"execute_tools\", execute_tools)\n",
    "builder.add_node(\"revise\", revisor.respond)\n",
    "# draft -> execute_tools\n",
    "builder.add_edge(\"draft\", \"execute_tools\")\n",
    "# execute_tools -> revise\n",
    "builder.add_edge(\"execute_tools\", \"revise\")\n",
    "\n",
    "# Define looping logic:\n",
    "\n",
    "\n",
    "def _get_num_iterations(state: List[BaseMessage]):\n",
    "    i = 0\n",
    "    for m in state[::-1]:\n",
    "        if not isinstance(m, (ToolMessage, AIMessage)):\n",
    "            break\n",
    "        i += 1\n",
    "    return i\n",
    "\n",
    "\n",
    "def event_loop(state: List[BaseMessage]) -> str:\n",
    "    # in our case, we'll just stop after N plans\n",
    "    num_iterations = _get_num_iterations(state)\n",
    "    if num_iterations > MAX_ITERATIONS:\n",
    "        return END\n",
    "    return \"execute_tools\"\n",
    "\n",
    "\n",
    "# revise -> execute_tools OR end\n",
    "builder.add_conditional_edges(\"revise\", event_loop)\n",
    "builder.set_entry_point(\"draft\")\n",
    "graph = builder.compile()"
   ]
  },
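  {
   "cell_type": "markdown",
   "id": "sketch-stop-condition",
   "metadata": {},
   "source": [
    "The stopping rule above counts trailing messages rather than revise cycles. A minimal, self-contained sketch of that behaviour follows; the classes `Human`, `AI`, and `Tool` are hypothetical stand-ins for the LangChain message types:\n",
    "\n",
    "```python\n",
    "# Hypothetical stand-ins for HumanMessage, AIMessage and ToolMessage.\n",
    "class Human:\n",
    "    pass\n",
    "\n",
    "\n",
    "class AI:\n",
    "    pass\n",
    "\n",
    "\n",
    "class Tool:\n",
    "    pass\n",
    "\n",
    "\n",
    "def count_trailing_steps(state):\n",
    "    # Count AI/tool messages after the most recent human message.\n",
    "    count = 0\n",
    "    for message in reversed(state):\n",
    "        if not isinstance(message, (AI, Tool)):\n",
    "            break\n",
    "        count += 1\n",
    "    return count\n",
    "\n",
    "\n",
    "# draft -> execute_tools -> revise yields three trailing messages.\n",
    "state = [Human(), AI(), Tool(), AI()]\n",
    "steps_after_first_revise = count_trailing_steps(state)\n",
    "\n",
    "# Each further execute_tools -> revise cycle adds two more messages.\n",
    "state += [Tool(), AI(), Tool(), AI()]\n",
    "steps_after_third_revise = count_trailing_steps(state)\n",
    "```\n",
    "\n",
    "With `MAX_ITERATIONS = 5`, the count first exceeds the limit after the third revise step (seven messages), which is where the loop ends."
   ]
  },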
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "2634a3ea-7423-4579-9f4e-390e439c3209",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "## 1. draft\n",
      "content='' additional_kwargs={'tool_calls': [{'id': 'call_GOmUTyAeA8kLLm4G9sXZ4jGV', 'function': {'a ...\n",
      "---\n",
      "## 2. execute_tools\n",
      "[ToolMessage(content='{\"successful climate policies examples\": [{\"url\": \"https://www.washingtonpost. ...\n",
      "---\n",
      "## 3. revise\n",
      "content='' additional_kwargs={'tool_calls': [{'id': 'call_Z0dky70zr74bLi6dBfQTCj6j', 'function': {'a ...\n",
      "---\n",
      "## 4. execute_tools\n",
      "[ToolMessage(content='{\"successful climate policies examples\": [{\"url\": \"https://www.washingtonpost. ...\n",
      "---\n",
      "## 5. revise\n",
      "content='' additional_kwargs={'tool_calls': [{'id': 'call_tM6DgVvQoux8IDIkRlrUKwOj', 'function': {'a ...\n",
      "---\n",
      "## 6. execute_tools\n",
      "[ToolMessage(content='{\"successful climate policies examples\": [{\"url\": \"https://www.washingtonpost. ...\n",
      "---\n",
      "## 7. revise\n",
      "content='' additional_kwargs={'tool_calls': [{'id': 'call_XkarDDuEf43cOBPn9zNXN8vM', 'function': {'a ...\n",
      "---\n",
      "## 8. __end__\n",
      "[HumanMessage(content='How should we handle the climate crisis?'), AIMessage(content='', additional_ ...\n",
      "---\n"
     ]
    }
   ],
   "source": [
    "events = graph.stream(\n",
    "    [HumanMessage(content=\"How should we handle the climate crisis?\")]\n",
    ")\n",
    "for i, step in enumerate(events):\n",
    "    node, output = next(iter(step.items()))\n",
    "    print(f\"## {i+1}. {node}\")\n",
    "    print(str(output)[:100] + \" ...\")\n",
    "    print(\"---\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "c9195436-aed9-4356-948b-2ca081a6d0bf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Addressing the climate crisis requires a comprehensive approach, combining policy, technology, and finance. Successful policy initiatives include the U.S. Army's carbon footprint reduction, federal funding to plug methane-leaking wells, and Ithaca, NY's building decarbonization [1]. Renewable energy, particularly solar and wind, is forecasted to surpass coal by 2025, demonstrating the critical role of transitioning to sustainable energy sources [2]. Technological innovations are essential, with significant advancements in solar cell efficiency, data-driven climate adaptation technologies, and efforts to replace or mitigate major emission sources [3][4][5]. Financial mechanisms are pivotal, with climate finance needing a substantial increase to meet global warming limits. The U.S. has made progress by significantly enhancing its international public climate finance, exemplifying financial commitment to supporting global climate action [6]. This holistic strategy, integrating policy, technological innovation, and finance, represents the most effective way to tackle the climate crisis, emphasizing sustainability, innovation, and global cooperation.\n",
      "\n",
      "References:\n",
      "[1] https://www.washingtonpost.com/climate-solutions/2022/04/21/climate-change-policy-examples-list/\n",
      "[2] https://www.weforum.org/agenda/2024/01/climate-transition-tipping-point/\n",
      "[3] https://www.technologyreview.com/2024/01/11/1086412/three-climate-technologies-breaking-through-in-2024/\n",
      "[4] https://unfccc.int/news/how-climate-technology-is-being-ramped-up\n",
      "[5] https://www.weforum.org/agenda/2024/02/ai-climate-adaptation-technologies/\n",
      "[6] https://www.state.gov/progress-report-on-president-bidens-climate-finance-pledge/\n"
     ]
    }
   ],
   "source": [
    "print(parser.invoke(step[END][-1])[0][\"args\"][\"answer\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7159e30c-728e-480d-8252-915404cc756d",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "Congrats on building a Reflexion actor! I'll leave you with a few observations to save you some time when choosing which parts of this agent ot adapt to your workflow:\n",
    "1. This agent trades off execution time for quality. It explicitly forces the agent to critique and revise the output over several steps, which usually (not always) increases the response quality but takes much longer to return a final answer\n",
    "2. The 'reflections' can be paired with additional external feedback (such as validators), to further guide the actor.\n",
    "3. In the paper, 1 environment (AlfWorld) uses external memory. It does this by storing summaries of the reflections to an external store and using them in subsequent trials/invocations."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
